# Copyright (c) Facebook, Inc. and its affiliates.
import math
import torch
import torch.nn.functional as F
from torch import nn

from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule, auto_fp16
from mmdet.models.builder import NECKS

from typing import List, Optional, Dict, Any


# Public API of this module: the only name exported by ``from <module> import *``.
__all__ = ["DetectronFPN"]


@NECKS.register_module()
class DetectronFPN(BaseModule):
    """
    This module implements :paper:`FPN`.
    It creates pyramid features built on top of some input feature maps.

    For every entry of ``in_channels`` a 1x1 lateral ``ConvModule`` and a 3x3
    output ``ConvModule`` are created. ``forward`` runs the top-down pathway
    over the backbone outputs selected by ``in_indices`` and returns the
    pyramid levels selected by ``out_indices``.
    """

    _fuse_type: str

    def __init__(
        self,
        in_indices: List[int],
        out_indices: List[int],
        in_channels: List[int],
        out_channels: int,
        start_level: int,
        conv_cfg: Optional[Dict[str, Any]] = dict(type="Conv2d"),
        norm_cfg: Optional[Dict[str, Any]] = dict(type="BN2d"),
        act_cfg: Optional[Dict[str, Any]] = None,
        fuse_type: Optional[str] = "sum"
    ):
        """
        Args:
            in_indices (list[int]): indices used to pick the input feature maps
                out of the ``bottom_up_features`` passed to :meth:`forward`.
                Expected to hold one index per entry of ``in_channels``, in the
                same order (from high to low resolution).
            out_indices (list[int]): positions of the pyramid levels to return
                from :meth:`forward`. Position 0 is the level built from
                ``in_indices[0]`` (highest resolution). The levels are always
                returned in ascending position order.
            in_channels (list[int]): number of channels of each input feature
                map, ordered from high to low resolution.
            out_channels (int): number of channels in the output feature maps.
            start_level (int): stage number of the first (highest resolution)
                input. It is only used to name the registered submodules
                ``fpn_lateral<stage>`` / ``fpn_output<stage>``.
            conv_cfg (dict or None): convolution config forwarded to every
                ``ConvModule``.
            norm_cfg (dict or None): normalization config forwarded to every
                ``ConvModule``. When it is None the lateral convolutions are
                created with a bias.
            act_cfg (dict or None): activation config forwarded to every
                ``ConvModule``.
            fuse_type (str): types for fusing the top down features and the lateral
                ones. It can be "sum" (default), which sums up element-wise; or "avg",
                which takes the element-wise mean of the two.
        """
        super(DetectronFPN, self).__init__()

        # Reject an unsupported fuse type before any layer is built.
        assert fuse_type in {"avg", "sum"}, (
            'fuse_type must be "sum" or "avg", got {!r}'.format(fuse_type)
        )

        # NOTE: the default cfg dicts in the signature are shared between
        # instances; this class only passes them through to ConvModule.
        lateral_convs = []
        output_convs = []

        # A bias on the lateral conv is only requested when no norm layer is configured.
        use_bias = norm_cfg is None
        for stage, in_channel in enumerate(in_channels, start=start_level):
            lateral_conv = ConvModule(
                in_channel,
                out_channels,
                kernel_size=1,
                bias=use_bias,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
            output_conv = ConvModule(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
            # Registration under these names is what exposes the parameters
            # (and fixes the state-dict keys); the lists below are plain
            # Python lists used only for ordered access in forward.
            self.add_module("fpn_lateral{}".format(stage), lateral_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)

            lateral_convs.append(lateral_conv)
            output_convs.append(output_conv)

        # Place convs into top-down order (from low to high resolution)
        # to make the top-down computation in forward clearer.
        self.lateral_convs = lateral_convs[::-1]
        self.output_convs = output_convs[::-1]
        self.in_indices = tuple(in_indices)
        self.out_indices = tuple(out_indices)
        self._fuse_type = fuse_type

    def forward(self, bottom_up_features):
        """
        Args:
            bottom_up_features: indexable collection of backbone feature maps
                (e.g. the tuple of stage outputs). It is indexed with the
                entries of ``in_indices``, which are expected to select maps
                in high to low resolution order.
        Returns:
            list[Tensor]:
                the FPN feature maps at the pyramid positions listed in
                ``out_indices``, in ascending position order (position 0 is
                the highest resolution level).
        """
        # Start the top-down pathway at the coarsest selected input.
        coarsest = bottom_up_features[self.in_indices[-1]]
        prev_features = self.lateral_convs[0](coarsest)
        # Collected coarse-to-fine; reversed once at the end.
        results = [self.output_convs[0](prev_features)]

        # The conv lists are already in top-down order (low to high
        # resolution); the first pair was consumed above.
        remaining = zip(self.lateral_convs[1:], self.output_convs[1:])
        for level, (lateral_conv, output_conv) in enumerate(remaining, start=1):
            features = bottom_up_features[self.in_indices[-level - 1]]
            lateral_features = lateral_conv(features)
            # Upsample to the exact spatial size (last two dimensions) of the
            # lateral map instead of a fixed 2x factor, so maps whose size is
            # not an exact multiple (e.g. 25 vs 13) can still be fused. For
            # exact 2x inputs the result is identical.
            top_down_features = F.interpolate(
                prev_features, size=lateral_features.shape[-2:], mode="nearest"
            )
            prev_features = lateral_features + top_down_features
            if self._fuse_type == "avg":
                prev_features /= 2
            results.append(output_conv(prev_features))

        # Back to high-to-low resolution order so that position 0 is the finest level.
        results.reverse()
        return [results[x] for x in sorted(self.out_indices)]
